#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"

#include <THCAllocator.h>
#include <THCStorage.h>
#include <THCTensor.h>

#include <THCUNN.h>

namespace caffe2 {

namespace {

// Returns the THCState owned by the calling thread, allocating it and running
// THCudaInit on it the first time that thread asks for one. Later calls on the
// same thread return the same pointer.
//
// TODO don't leak the THCState. We only have as many threads as
// e.g. the number of AsyncDAGNet worker threads, which is small
// (O(numGPUs)).
THCState* getTHCState() {
  static thread_local THCState* perThreadState = nullptr;
  if (perThreadState != nullptr) {
    return perThreadState;
  }
  perThreadState = new THCState();
  THCudaInit(perThreadState);
  CHECK_NOTNULL(perThreadState);
  return perThreadState;
}

struct THCudaTensorDeleter {
  explicit THCudaTensorDeleter(THCState* state)
      : state_(CHECK_NOTNULL(state)) {}
  void operator()(THCudaTensor* th) {
    THCudaTensor_free(state_, th);
  }

  THCState* state_{nullptr};
};

using UniqueTHCudaTensor = std::unique_ptr<THCudaTensor, THCudaTensorDeleter>;

// Returns this thread's THCState (from getTHCState) after redirecting it to
// `context`. The current THCStream's raw stream is set to
// context->cuda_stream(). The current cuBLAS handle slot for the active
// device is set to context->cublas_handle(). Work issued through the returned
// state is thereby intended to use the same stream/handle as the caffe2
// operator.
//
// NOTE(review): both assignments mutate THC-owned objects in place. The
// previous stream and cuBLAS handle are neither restored nor destroyed; see
// the TODOs below. Statement order matters: the handle slot is reserved
// before it is overwritten.
THCState* thnnState(CUDAContext* context) {
  THCState* state = getTHCState();
  THCStream* stream = THCState_getStream(state);
  // TODO - swap these back after we're done before we handle
  // deletion.
  // TODO - handle proper destroy of existing handle
  // (if not already caffe2 set handle)
  stream->stream = context->cuda_stream();

  // TODO - destroy the current handle
  int device;
  THCudaCheck(cudaGetDevice(&device));
  int blasHandleIndex = THCState_getCurrentBlasHandleIndex(state);
  // Called only for its side effect (the returned handle is discarded).
  THCState_getDeviceBlasHandle(state, device, blasHandleIndex); // to reserve
  THCCudaResourcesPerDevice* res = &(state->resourcesPerDevice[device]);
  // The `- 1` suggests THC's handle index is 1-based -- confirm against the
  // linked THC version before changing this.
  res->blasHandles[blasHandleIndex - 1] = context->cublas_handle();
  return state;
}

// Returns a THCudaTensor with `tensor`'s dims that aliases (does not copy)
// `tensor`'s float buffer, so it can be handed to THCUNN kernels.
//
// A tensor with ndim() == 0 is mapped to a fresh, empty THCudaTensor.
// Otherwise a THCudaStorage is wrapped around tensor->mutable_data<float>().
// Its TH_STORAGE_FREEMEM flag is cleared, so TH never frees caffe2-owned
// memory. As a consequence, the returned tensor must not be used after
// `tensor`'s buffer is released or reallocated.
//
// The new THCudaTensor is owned by a UniqueTHCudaTensor as soon as it is
// created. Its reference is therefore released even if the final
// consistency check throws.
//
// @param context CUDA context whose stream/cuBLAS handle THC should use.
// @param tensor  caffe2 tensor whose float data is aliased.
// @return owning handle to the aliasing THCudaTensor.
UniqueTHCudaTensor aliasFromTensorCUDA(
    CUDAContext* context,
    TensorCUDA* tensor) {
  auto* state = thnnState(context);
  if (!tensor->ndim()) {
    return UniqueTHCudaTensor(
        THCudaTensor_new(state), THCudaTensorDeleter(state));
  }
  THLongStorage* thshape = THLongStorage_newWithSize(tensor->ndim());
  for (int i = 0; i < tensor->ndim(); ++i) {
    THLongStorage_set(thshape, i, tensor->dim(i));
  }
  float* data = tensor->mutable_data<float>();
  THCudaStorage* storage =
      THCudaStorage_newWithData(state, data, tensor->size());
  // caffe2 owns `data`; TH must not free it when the storage dies.
  THCudaStorage_clearFlag(state, storage, TH_STORAGE_FREEMEM);
  UniqueTHCudaTensor th(
      THCudaTensor_newWithStorage(state, storage, 0, thshape, nullptr),
      THCudaTensorDeleter(state));
  // The tensor now holds its own reference to `storage` (and copied `thshape`).
  THCudaStorage_free(state, storage);
  THLongStorage_free(thshape);
  CAFFE_ENFORCE_EQ(THCudaTensor_storage(state, th.get())->data, data);
  return th;
}

// Writes the result held in a THCudaTensor back into a caffe2 TensorCUDA.
//
// `th` is first made contiguous (THCudaTensor_newContiguous). Then one of
// two things happens:
//  * If it already has `tensor`'s dims and its first element is at
//    `tensor`'s data pointer, TH computed in place in the aliased buffer.
//    Nothing needs copying, so we return early.
//  * Otherwise `tensor` is resized to `th`'s dims and the elements are
//    copied device-to-device through `context`.
//
// The source pointer comes from THCudaTensor_data. That function applies the
// tensor's storage offset and yields nullptr when the tensor has no storage
// (e.g. a THCudaTensor_new() result resized to zero elements -- per THC's
// resize logic; confirm for the linked THC version). An empty result is
// therefore handled instead of dereferencing a null THCudaStorage.
//
// @param context CUDA context whose stream/cuBLAS handle THC should use.
// @param th      TH tensor holding the result; ownership is taken and it is
//                released when this function returns.
// @param tensor  destination caffe2 tensor; resized to `th`'s dims.
void copyToTensorCUDA(
    CUDAContext* context,
    UniqueTHCudaTensor th,
    TensorCUDA* tensor) {
  auto* state = thnnState(context);
  // As contiguous (the previous reference held by `th` is released).
  th = UniqueTHCudaTensor(
      THCudaTensor_newContiguous(state, th.get()),
      THCudaTensorDeleter(state));
  const auto dims = std::vector<TIndex>(
      th->size, th->size + THCudaTensor_nDimension(state, th.get()));
  float* src = THCudaTensor_data(state, th.get());
  // Short-circuit if TH wrote directly into `tensor`'s buffer and never
  // reallocated.
  if (src != nullptr && dims == tensor->dims() &&
      src == tensor->data<float>()) {
    THCudaStorage_clearFlag(
        state, THCudaTensor_storage(state, th.get()), TH_STORAGE_FREEMEM);
    return;
  }

  tensor->Resize(dims);
  float* dst = tensor->mutable_data<float>();
  if (tensor->size() == 0) {
    // Nothing to copy (and `src` may legitimately be null).
    return;
  }
  CAFFE_ENFORCE(src != nullptr, "THNN result has no storage to copy from");
  context->Copy<float, CUDAContext, CUDAContext>(tensor->size(), src, dst);
}

// _Everything_ below here could be autogenerated from the planned (not yet
// defined) THNN/THCUNN schema. This is just a proof of concept.

// ELU forward operator that runs THCUNN's ELU kernel on THCudaTensors
// aliasing the caffe2 input/output buffers, then writes the result back into
// Output(0).
class THNNELUCUDAOp final : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);
  using Operator<CUDAContext>::Operator;

  bool RunOnDevice() override {
    // TODO - we can autogenerate this from a schema.
    THCState* thState = thnnState(&context_);
    const TensorCUDA& input = Input(0);
    auto inputTH =
        aliasFromTensorCUDA(&context_, const_cast<TensorCUDA*>(&input));
    TensorCUDA* output = Output(0);
    auto outputTH = aliasFromTensorCUDA(&context_, output);
    // The kernel runs in place when input and output are the same tensor.
    const bool inPlace = (&input == output);
    const float alpha = GetSingleArgument<float>("alpha", 1.0);
    THNN_CudaELU_updateOutput(
        thState, inputTH.get(), outputTH.get(), alpha, inPlace);
    copyToTensorCUDA(&context_, std::move(outputTH), output);
    return true;
  }
};

// ELU backward operator built on THCUNN's ELU_updateGradInput.
// Inputs: 0 = X, 1 = Y (forward output), 2 = dY. Output: 0 = dX.
// NOTE(review): this expects three inputs; confirm it matches the input list
// the ELU gradient maker actually emits for this engine.
class THNNELUCUDAGradientOp final : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);
  using Operator<CUDAContext>::Operator;

  bool RunOnDevice() override {
    // TODO - we can autogenerate this from a schema.
    THCState* thState = thnnState(&context_);
    const TensorCUDA& x = Input(0);
    auto xTH = aliasFromTensorCUDA(&context_, const_cast<TensorCUDA*>(&x));
    const TensorCUDA& y = Input(1);
    auto yTH = aliasFromTensorCUDA(&context_, const_cast<TensorCUDA*>(&y));
    const TensorCUDA& dY = Input(2);
    auto dYTH = aliasFromTensorCUDA(&context_, const_cast<TensorCUDA*>(&dY));
    TensorCUDA* dX = Output(0);
    auto dXTH = aliasFromTensorCUDA(&context_, dX);
    // In place when the incoming gradient and dX are the same tensor.
    const bool inPlace = (&dY == dX);
    const float alpha = GetSingleArgument<float>("alpha", 1.0);
    THNN_CudaELU_updateGradInput(
        thState, xTH.get(), dYTH.get(), dXTH.get(), yTH.get(), alpha, inPlace);
    copyToTensorCUDA(&context_, std::move(dXTH), dX);
    return true;
  }
};

// Register the THCUNN-backed kernels under engine "THNN" for the CUDA ELU and
// ELUGradient operators (presumably picked when an OperatorDef requests that
// engine -- confirm against the operator registry's engine selection).
REGISTER_CUDA_OPERATOR_WITH_ENGINE(ELU, THNN, THNNELUCUDAOp);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(ELUGradient, THNN, THNNELUCUDAGradientOp);
}
}
